package com.hyving.home.mybatismulti.aop;


import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.annotation.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.stereotype.Component;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.Stack;

@Component
@Aspect
@Slf4j
public class MultiTransactionAop {

    private static final ThreadLocal<Stack<Map<DataSourceTransactionManager, TransactionStatus>>> THREAD_LOCAL = new ThreadLocal<>();
    /**
     * 用于获取事务管理器
     */
    @Autowired
    private ApplicationContext applicationContext;
    /**
     * 事务声明
     */
    private DefaultTransactionDefinition def = new DefaultTransactionDefinition();
    {
        // 非只读模式
        def.setReadOnly(false);
        // 事务隔离级别：采用数据库的
        def.setIsolationLevel(TransactionDefinition.ISOLATION_DEFAULT);
        // 事务传播行为
        def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
    }
    /**
     * 切点
     */
    @Pointcut("@annotation(com.hyving.home.mybatismulti.aop.MultiDataSourceTransactional)")
    public void pointcut() {
    }
    /**
     * 声明事务
     *
     * @param transactional 注解
     */
    @Before("pointcut() && @annotation(transactional)")
    public void before(MultiDataSourceTransactional transactional) {
        // 根据设置的事务名称按顺序声明，并放到ThreadLocal里
        String[] transactionManagerNames = transactional.transactionManagers();
        Stack<Map<DataSourceTransactionManager, TransactionStatus>> pairStack = new Stack<>();
        for (String transactionManagerName : transactionManagerNames) {
            DataSourceTransactionManager transactionManager = applicationContext.getBean(transactionManagerName, DataSourceTransactionManager.class);
            TransactionStatus transactionStatus = transactionManager.getTransaction(def);
            Map<DataSourceTransactionManager, TransactionStatus> transactionMap = new HashMap<>();
            transactionMap.put(transactionManager, transactionStatus);
            pairStack.push(transactionMap);
        }
        THREAD_LOCAL.set(pairStack);
    }
    /**
     * 后置增强，相当于AfterReturningAdvice，方法退出时执行
     *
     * 提交事务
     */
    @AfterReturning("pointcut()")
    public void afterReturning() {
        // ※栈顶弹出（后进先出）
        Stack<Map<DataSourceTransactionManager, TransactionStatus>> pairStack = THREAD_LOCAL.get();
        while (!pairStack.empty()) {
            Map<DataSourceTransactionManager, TransactionStatus> pair = pairStack.pop();
            pair.forEach((key,value)->key.commit(value));
        }
        THREAD_LOCAL.remove();
    }
    /**
     * 异常抛出增强，相当于ThrowsAdvice
     *
     * 回滚事务
     */
    @AfterThrowing(value = "pointcut()")
    public void afterThrowing() {
        // ※栈顶弹出（后进先出）
        Stack<Map<DataSourceTransactionManager, TransactionStatus>> pairStack = THREAD_LOCAL.get();
        log.info("=========================");
        log.info("Pair Stack:{}", pairStack);
        log.info("=========================");
        while (!pairStack.empty()) {
            Map<DataSourceTransactionManager, TransactionStatus> pair = pairStack.pop();
            pair.forEach((key,value)->key.rollback(value));
        }
        THREAD_LOCAL.remove();
    }
}
